[JAX] Add an MoE Block (Layer) that compounds router, permutation, groupedGEMM and communication #2912
tdophung wants to merge 14 commits into
Conversation
Greptile Summary
This PR introduces
Confidence Score: 4/5
Safe to merge after addressing the missing ep_axis exclusion guard in the data_parallelism_axes validation loop. The A2A-EP forward path correctly incorporates the recv_buffer_rows alignment fix. One gap remains: if a caller passes the EP axis name in data_parallelism_axes, the batch PartitionSpec gets a duplicate axis and dp_size is double-counted, producing an undersized ragged_all_to_all receive buffer with no useful error message. See transformer_engine/jax/flax/moe.py, specifically the data_parallelism_axes validation block in _forward_a2a_ep.
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
participant Caller
participant _MoEBlock
participant Router
participant GlobalPermute
participant A2A as ragged_all_to_all (EP)
participant LocalPerm as local_permute_after_a2a
participant ExpertFFN as _expert_ffn (grouped_dense x3)
participant GlobalCombine
Caller->>_MoEBlock: inputs [B, S, H]
_MoEBlock->>Router: "gate_logits -> fused_topk_with_score_function"
Router-->>_MoEBlock: sparse_probs, routing_map
_MoEBlock->>GlobalPermute: _global_permute (pure_jax or triton)
GlobalPermute-->>_MoEBlock: sorted_inputs, group_sizes [E]
alt No-EP path
_MoEBlock->>ExpertFFN: "sorted_inputs, group_sizes, n_groups=E"
ExpertFFN-->>_MoEBlock: expert_outputs
else A2A-EP path via shard_map
_MoEBlock->>A2A: all_gather(group_sizes)
A2A->>A2A: forward ragged_all_to_all over ep axis
A2A->>LocalPerm: reorder recv buffer
LocalPerm-->>A2A: sorted_x, local_group_sizes
A2A->>ExpertFFN: sorted_x, local_group_sizes
ExpertFFN-->>A2A: expert_outputs
A2A->>LocalPerm: local_unpermute_before_a2a
A2A->>A2A: reverse ragged_all_to_all
A2A-->>_MoEBlock: y_back
end
_MoEBlock->>GlobalCombine: _global_combine
GlobalCombine-->>_MoEBlock: output [B, S, H]
_MoEBlock-->>Caller: output [B, S, H], aux_loss
```
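To make the no-EP path above concrete, here is a minimal eager-mode sketch in plain JAX (not TE's fused kernels) of the permute -> grouped FFN -> combine flow; all shapes and the per-expert "FFN" are illustrative stand-ins.

```python
import jax
import jax.numpy as jnp

T, E, K, H = 6, 4, 2, 8                                     # tokens, experts, top-k, hidden
x = jax.random.normal(jax.random.PRNGKey(0), (T, H))
logits = jax.random.normal(jax.random.PRNGKey(1), (T, E))
probs, topk_idx = jax.lax.top_k(jax.nn.softmax(logits), K)  # sparse probs / routing, [T, K]

# Global permute: replicate each token once per selected expert and group rows by expert id.
token_idx = jnp.repeat(jnp.arange(T), K)                    # [T*K]
expert_idx = topk_idx.reshape(-1)                           # [T*K]
order = jnp.argsort(expert_idx)
sorted_inputs = x[token_idx[order]]                         # [T*K, H]
group_sizes = jnp.bincount(expert_idx, length=E)            # rows per expert, [E]

# Expert FFN: stand-in for the 3x grouped-GEMM SwiGLU; here just a per-expert scale.
expert_scale = jnp.arange(1.0, E + 1.0)
expert_outputs = sorted_inputs * expert_scale[expert_idx[order], None]

# Global combine: weight by routing prob and scatter-add back into token order.
weights = probs.reshape(-1)[order][:, None]                 # [T*K, 1]
output = jnp.zeros((T, H)).at[token_idx[order]].add(expert_outputs * weights)
print(output.shape)                                         # (T, H)
```

The A2A-EP path in the diagram additionally redistributes `sorted_inputs` across EP ranks with a ragged all-to-all before the expert computation and reverses it afterwards.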
Reviews (6): Last reviewed commit: "change naming and add message for experi..."
Signed-off-by: tdophung <tdophung@nvidia.com>
…ody single GPU vs. multi GPU Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
…e and single device initial params in the MoEBlock. Tests should pass now Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: tdophung <tdophung@nvidia.com>
```python
def _compute_aux_loss(
    self,
    logits_2d: jnp.ndarray,
) -> Optional[jnp.ndarray]:
    """Compute the MoE auxiliary load-balancing loss.

    The score-for-aux kernel has no data dependency on the main
    routing kernel, so XLA can overlap them on the GPU.

    ``logits_2d`` should be the *full* logits tensor over the global
    token batch -- under EP the caller is responsible for
    :func:`jax.lax.all_gather` ing the logits before calling this so
    the aux_loss formula
    ``loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t,i]) * tokens[i])``
    sees the global ``T`` and the global ``tokens_per_expert``.
    """
    if self.aux_loss_coeff <= 0.0:
        return None
    aux_scores, aux_routing_map = fused_topk_with_score_function(
        logits_2d.astype(jnp.float32),
        topk=self.num_experts_per_tok,
        score_function=self.score_function,
        compute_aux_scores=True,
    )
    aux_tokens_per_expert = jnp.sum(aux_routing_map.astype(jnp.int32), axis=0)
    return fused_moe_aux_loss(
        aux_scores.astype(jnp.float32),
        aux_tokens_per_expert,
        topk=self.num_experts_per_tok,
        coeff=self.aux_loss_coeff,
    )
```
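As a sanity check on the docstring formula, a plain-jnp reference (standard softmax scoring and standard top-k only, no grouped routing; shapes are made up) can be written as:

```python
import jax
import jax.numpy as jnp

def reference_aux_loss(logits_2d, topk, coeff):
    """Plain-jnp check of loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t, i]) * tokens[i])."""
    T, E = logits_2d.shape
    probs = jax.nn.softmax(logits_2d.astype(jnp.float32), axis=-1)   # dense probs, [T, E]
    _, topk_idx = jax.lax.top_k(probs, topk)                         # standard top-k routing, [T, k]
    routing_map = jnp.zeros((T, E), jnp.int32).at[jnp.arange(T)[:, None], topk_idx].set(1)
    tokens_per_expert = routing_map.sum(axis=0)                      # [E]
    return (E * coeff / (topk * T**2)) * jnp.sum(probs.sum(axis=0) * tokens_per_expert)

logits = jax.random.normal(jax.random.PRNGKey(0), (16, 8))           # 16 tokens, 8 experts
print(reference_aux_loss(logits, topk=2, coeff=0.01))
```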
Aux loss tokens_per_expert is inconsistent with actual grouped-topk routing
When num_groups > 0 and group_topk > 0 (DeepSeek-style routing), fused_topk_with_score_function(..., compute_aux_scores=True) intentionally ignores those parameters and runs a clean standard top-k. The returned aux_routing_map therefore reflects different expert selections than the actual routing_map produced by _route_topk, causing aux_tokens_per_expert = sum(aux_routing_map, axis=0) to count a different token–expert distribution. Any user who combines num_groups > 0 + group_topk > 0 + aux_loss_coeff > 0 silently trains with a wrong auxiliary objective. The existing test_group_topk_deepseek test does not catch this because it leaves aux_loss_coeff at its default of 0.0.
…int in C++ files, make FP8 works. Tested with current scaling Signed-off-by: JAX Toolbox <jax@nvidia.com>
for more information, see https://pre-commit.ci
… grad tol to 5e-2, move arch/align_size docs into MoEBlock class Signed-off-by: tdophung <tdophung@nvidia.com>
```python
batch_divisor = num_ep * dp_size
if global_batch_size % batch_divisor != 0:
    raise ValueError(
        f"batch={global_batch_size} not divisible by prod(data_parallelism_axes)={dp_size}"
    )
recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk
```
Receive buffer undersized when align_size > 0 and EP are combined
recv_buffer_rows is computed assuming unpadded token counts, but when align_size > 0 the per-expert group_sizes are the aligned counts, so the send_sizes in compute_ragged_all_to_all_params include padding tokens. The worst-case receive per shard is num_ep * ((B/(num_ep*dp_size))*S*K + num_experts_per_shard*(align_size-1)), which exceeds the current recv_buffer_rows = (B/dp_size)*S*K by up to num_experts*(align_size-1) rows. ragged_all_to_all writing beyond the buffer produces incorrect results or a crash. The correct worst-case size is:
```python
recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk + num_experts * (self.align_size - 1 if self.align_size > 0 else 0)
```
This combination (EP + align_size > 0) is not exercised by the current distributed test (which defaults to align_size=0), so the bug is latent.
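Plugging an assumed configuration into the three formulas makes the gap visible; the numbers below are illustrative, not taken from the test suite.

```python
# Illustrative sizing check; configuration values are assumptions, not from the tests.
B, S, K = 8, 128, 2                 # global batch, sequence length, top-k
num_ep, dp_size = 2, 2              # EP shards, product of data_parallelism_axes
num_experts, align_size = 8, 16

current = (B // dp_size) * S * K
worst_case = num_ep * ((B // (num_ep * dp_size)) * S * K
                       + (num_experts // num_ep) * (align_size - 1))
proposed = (B // dp_size) * S * K + num_experts * (align_size - 1)

print(current, worst_case, proposed)   # 1024 1144 1144 -> current buffer is 120 rows short
```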
phu0ngng left a comment
I think we should go with exposing GroupMLP VJP first before the MoE module to enable future possible fusions.
…ing None as group_topk, align_size rename, Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
```python
for ax in self.data_parallelism_axes:
    if ax not in mesh.shape:
        raise ValueError(
            f"data_parallelism_axes contains {ax!r} but mesh has"
            f" axes {tuple(mesh.shape.keys())}"
        )
```
The validation loop checks that every axis in data_parallelism_axes exists in the mesh but does not check that the axis differs from ep_axis. If a caller passes data_parallelism_axes=("ep",) when ep_axis="ep", batch_pspec_axis becomes ("ep", "ep"), a duplicate-axis PartitionSpec that JAX rejects with a cryptic error. Independently, dp_size accumulates mesh.shape["ep"] a second time, so recv_buffer_rows is undersized by a factor of num_ep and batch_divisor becomes num_ep², both causing wrong runtime behaviour before JAX ever sees the bad spec.
Suggested change:

```diff
 for ax in self.data_parallelism_axes:
     if ax not in mesh.shape:
         raise ValueError(
             f"data_parallelism_axes contains {ax!r} but mesh has"
             f" axes {tuple(mesh.shape.keys())}"
         )
+    if ax == ep_axis:
+        raise ValueError(
+            f"data_parallelism_axes contains {ax!r}, which is the same as the"
+            f" EP axis {ep_axis!r}. The EP axis is already included in the batch"
+            " sharding spec; listing it again produces a duplicate-axis"
+            " PartitionSpec and an undersized ragged_all_to_all receive buffer."
+        )
```
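For reference, a minimal single-device sketch of the duplicate-axis failure mode the guard is meant to catch; the exact exception type and message depend on the JAX version, so this only illustrates that the spec is rejected.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-device mesh with an "ep" axis is enough: data_parallelism_axes=("ep",) with
# ep_axis="ep" would produce a batch spec that names "ep" twice, as below.
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("ep",))
x = np.zeros((4, 8))
try:
    sharding = NamedSharding(mesh, P(("ep", "ep"), None))
    jax.device_put(x, sharding)
except Exception as e:  # exception type and message vary across JAX versions
    print("duplicate-axis spec rejected:", e)
```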
Description
Most of the MoE building-block integration work has been deeply coupled with Maxtext development. This PR creates an MoE block to isolate that work from Maxtext and provide more room for experimentation.
`MoEBlock` is a self-contained Flax-Linen module that wires together TE's fused router, pluggable token-dispatch backends (pure-JAX argsort or Triton `sort_chunks_by_index`), a `grouped_dense`-based expert FFN, and ragged-all-to-all (A2Av) expert parallelism via `shard_map`. This first iteration starts with ring-of-experts EP, sharding on the batch dimension for FSDP, cuBLASLt groupedGEMM, and two permutation backends: pure JAX or Triton kernels.

The block also exposes pluggable knobs for: weight layout (`wi_kernel_axes` / `wo_kernel_axes`), permutation backend, A2A vs. no-EP (single-GPU) path, data-parallelism axes for true FSDP (batch sharded across `(ep, fsdp)` simultaneously), top-K with optional grouped/sigmoid scoring (for the DSv3 workload), and an optional auxiliary load-balancing loss.

Fixes #2895
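A rough single-device usage sketch of the block on its no-EP path follows; the constructor keywords are inferred from the knobs listed above and are hypothetical, not the final signature. Only the `(output, aux_loss)` return shown in the sequence diagram is taken from the PR itself.

```python
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax.moe import MoEBlock  # module path per the Changes list below

# NOTE: keyword names are hypothetical (inferred from the description above),
# not the exact final MoEBlock signature.
block = MoEBlock(
    num_experts=8,
    num_experts_per_tok=2,        # top-K
    score_function="softmax",     # or "sigmoid" for DSv3-style routing
    aux_loss_coeff=0.01,          # 0.0 disables the auxiliary load-balancing loss
)

x = jnp.zeros((2, 128, 512))      # [batch, sequence, hidden]
params = block.init(jax.random.PRNGKey(0), x)
output, aux_loss = block.apply(params, x)   # output is [B, S, H] per the sequence diagram
```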
Type of change
Changes
- `transformer_engine/jax/flax/moe.py` -- `MoEBlock` Linen module: gate -> fused topk -> global permute -> A2A EP shard_map (ragged_a2a fwd, local permute, 3x grouped-GEMM SwiGLU FFN, local unpermute, ragged_a2a rev) -> global combine.
- `transformer_engine/jax/permutation.py` -- A2A param helpers (`compute_ragged_all_to_all_params`, `compute_reverse_ragged_all_to_all_params`, `local_permute_after_a2a`, `local_unpermute_before_a2a`) and the pure-JAX `unfused_token_dispatch` / `unfused_token_combine` paths with custom VJPs.
- `tests/jax/test_moe_block.py` -- single-device shape, backward, cross-backend equivalence, aux-loss, group-topk, and JIT-determinism tests.
- `tests/jax/test_distributed_moe_block.py` -- EP=2 x FSDP=2 mesh test using the canonical Flax-Linen sharded-init pattern (eval_shape -> get_partition_spec -> logical_to_mesh_sharding -> jit(init, out_shardings=...)) and `data_parallelism_axes=("fsdp",)` to exercise true FSDP (batch sharded across both axes).

Checklist: